package com.auditlog.sql.runner;

import cn.hutool.core.collection.CollectionUtil;
import com.auditlog.datasource.ParametersHolder;
import com.auditlog.datasource.StatementProxy;
import com.auditlog.datasource.db.SqlType;
import com.auditlog.datasource.struct.*;
import com.auditlog.exception.NotSupportException;
import com.auditlog.util.IndexUtils;
import lombok.extern.slf4j.Slf4j;
import net.sf.jsqlparser.expression.Expression;
import net.sf.jsqlparser.expression.operators.relational.ItemsList;
import net.sf.jsqlparser.expression.operators.relational.MultiExpressionList;
import net.sf.jsqlparser.schema.Column;
import net.sf.jsqlparser.statement.insert.Insert;

import java.sql.*;
import java.util.*;
import java.util.stream.Collectors;


/**
 * @author Zhiyang.Zhang
 * @version 1.0
 * @date 2022/9/23 0:14
 */
@Slf4j
public class InsertUpdateSqlRunner<T extends Statement> extends AbstractInsertSqlRunner<T> {

    public InsertUpdateSqlRunner(StatementProxy<T> statementProxy, Insert insert) {
        super(statementProxy, insert);
    }

    private List<Integer> pkIndexList = new ArrayList<>();
    private List<String> pkNameList = new ArrayList<>();
    private List<List<Integer>> ukIndexList = new ArrayList<>();
    private List<List<String>> ukNameList = new ArrayList<>();
    private List<List<Object>> filteredColumnsValue = new ArrayList<>();
    private List<List<Object>> columnsValue = new ArrayList<>();
    private int keysAllNullRowCount;
    private List<String> allColumns;

    @Override
    public void validate() {
        // 对于预编译的insert on duplicate key支持batch操作
        if (this.isBatch() && !(this.getStatementProxy() instanceof ParametersHolder)) {
            throw new NotSupportException("insert on duplicate key不能和多条批处理sql同时使用");
        }
    }

    @Override
    public RecordImage afterImage(RecordImage beforeRecordImage) throws Exception {
        if (keysAllNullRowCount > 0) {
            ResultSet generatedKeys = this.getStatementProxy().getGeneratedKeys();
            ResultSetMetaData metaData = generatedKeys.getMetaData();
            int columnCount = metaData.getColumnCount();
            List<List<Object>> pkValues = new ArrayList<>();
            while (generatedKeys.next()) {
                List<Object> ids = new ArrayList<>();
                for (int i = 1; i <= columnCount; i++) {
                    // 这里产生的key有可能不是真正的插入key（当插入的key数据库已存在时）
                    ids.add(generatedKeys.getObject(i));
                }
                pkValues.add(ids);
            }
            if (CollectionUtil.isNotEmpty(pkValues)) {
                TableRecord tableRecord = this.getTableRecord(this.getTableNameWithAlias(), this.pkNameList, pkValues, this.allColumns);
                beforeRecordImage.setAfterRecord(tableRecord);
                beforeRecordImage.getInsertPks().addAll(pkValues);
            }
            return beforeRecordImage;
        }
        ExistTableRow existTableRow = this.tableRows(this.pkIndexList, this.pkNameList, this.ukIndexList, this.ukNameList, this.allColumns, this.columnsValue);
        TableRecord afterRecord = TableRecord.builder().delimiter(this.getDelimiter()).rows(existTableRow.getRows()).build();
        beforeRecordImage.setAfterRecord(afterRecord);
        existTableRow.getPkValues().forEach(pkValue -> this.populatePks(beforeRecordImage, pkValue));
        return beforeRecordImage;
    }

    @Override
    public RecordImage beforeImage() throws SQLException {
        pkNameList = this.getTableMeta().getPrimaryKeyWithAlias(this.getAlias(), true);
        allColumns = this.getTableMeta().getColumnNamesByOrder(this.getAlias());
        Map<String, List<String>> uniqueKeyMap = this.getTableMeta().getUniqueKeyMap(this.getAlias(), true);
        // 更新列的值（暂时用不上，取数据库的）
        List<Expression> duplicateUpdateExpressionList = this.getStatement().getDuplicateUpdateExpressionList();
        // 更新的列
        List<Column> duplicateUpdateColumns = this.getStatement().getDuplicateUpdateColumns();
        // 转换为列名集合
        List<String> duplicateColumnsName = columnsName(duplicateUpdateColumns);
        this.columnsValue = this.extractInsertValues();
        if (columnsValue.size() == 0) {
            // 没有要插入和更新的数据返回
            return RecordImage.builder().
                    tableName(this.getTableName()).alias(this.getAlias())
                    .updateColumnsName(duplicateColumnsName)
                    .allColumnsWithAlias(this.getTableMeta().getColumnNamesByOrder(this.getAlias()))
                    .tableMeta(this.getTableMeta())
                    .pkColumnsName(pkNameList)
                    .build();
        }
        List<String> insertColumns = this.extractInsertColumns();
        return buildRecordImage(insertColumns, uniqueKeyMap);
    }


    public RecordImage buildRecordImage(List<String> columns, Map<String, List<String>> uniqueKeyMap) throws SQLException {
        pkIndexList = IndexUtils.getIndexList(pkNameList, columns);
        Map<String, List<Integer>> ukIndexMap = IndexUtils.getIndexList(uniqueKeyMap, columns);
        List<TupleTwo<List<String>, List<Integer>>> tupleTwos = new ArrayList<>();
        if (pkIndexList.size() > 0) {
            TupleTwo<List<String>, List<Integer>> tupleTwo = TupleTwo.<List<String>, List<Integer>>builder()
                    .v1(pkNameList).v2(pkIndexList)
                    .build();
            tupleTwos.add(tupleTwo);
        }
        ukNameList = new ArrayList<>();
        ukIndexList = new ArrayList<>();
        // 过滤掉唯一键值重复的数据
        for (Map.Entry<String, List<Integer>> entry : ukIndexMap.entrySet()) {
            String key = entry.getKey();
            List<String> names = uniqueKeyMap.get(key);
            List<Integer> indexes = entry.getValue();
            // 唯一键的位置集合
            ukNameList.add(names);
            ukIndexList.add(indexes);
            TupleTwo<List<String>, List<Integer>> tupleTwo = TupleTwo.<List<String>, List<Integer>>builder()
                    .v1(names).v2(indexes)
                    .build();
            tupleTwos.add(tupleTwo);
        }
        // 为什么这里不用过滤后的列值，为了防止一些key被过滤后，造成没有锁住，其他sql线程可以插入或者更新了
        // filteredColumnsValue = filter(tupleTwos, columnsValue);
        // 根据插入键值查询数据库存在的键值
        // 进行这里的查询作用是：
        //  1.为了分辨数据库中：待插入的两条数据唯一键不同但是只对应数据库的一条数据，比如：数据1：指定了唯一键1，数据2：指定了唯一键2，但是唯一键1和唯一键2对应的是相同的数据
        ExistTableRow existTableRow = tableRows(pkIndexList, pkNameList, ukIndexList, ukNameList,
                this.allColumns, this.columnsValue);
        int insertCount = getInsertCount(tupleTwos, this.columnsValue, existTableRow.getAllKeys());
        this.keysAllNullRowCount = getKeysAllNullRowCount(tupleTwos, this.columnsValue);
        if (this.keysAllNullRowCount > 0 && this.columnsValue.size() > 1) {
            // 因为on duplicate key update中，mysqlclient在获取产生的key时，只获取一个
            throw new IllegalArgumentException("insert on duplicate key只支持一条数据的插入,实际插入" + insertCount + "条");
        }
        if (this.keysAllNullRowCount > 0 && !this.getTableMeta().isAutoIncreasePk()) {
            throw new IllegalArgumentException("表[" + this.getTableName() + "]不是自增主键");
        }
        RecordImage recordImage = RecordImage.builder()
                .possibleUpdateOrDeletePks(existTableRow.getPkValues())
                .pkColumnsName(existTableRow.getPkColumns())
                .allColumnsWithAlias(this.allColumns)
                .tableMeta(this.getTableMeta())
                .alias(this.getAlias())
                .tableName(this.getTableName())
                .sqlType(SqlType.INSERT_DUPLICATE)
                .build();
        recordImage.getBeforeRecord().addRows(existTableRow.getRows());
        recordImage.getBeforeRecord().setDelimiter(this.getDelimiter());
        return recordImage;
    }


    /**
     * 获取主键列和被更新列和对应的值<br>
     * 1.利用待插入数据对应主键和唯一键值去数据库查询是否存在<br>
     * 2.将这些唯一键和主键分别组成key（这是用来判断插入条数的）返回<br>
     * 3.主键值返回<br></br>
     */
    public ExistTableRow tableRows(List<Integer> pkColumnsIndex,
                                   List<String> pkColumns,
                                   List<List<Integer>> ukColumnsIndex,
                                   List<List<String>> ukNameList,
                                   List<String> allColumns,
                                   List<List<Object>> columnsValue) throws SQLException {
        ExistTableRow.ExistTableRowBuilder columnMetaBuilder = ExistTableRow.builder().pkColumns(pkColumns);
        if (CollectionUtil.isEmpty(pkColumnsIndex) && CollectionUtil.isEmpty(ukColumnsIndex)) {
            return columnMetaBuilder.build();
        }

        // 主键列值
        TupleTwo<List<String>, List<List<Object>>> pkTuple = TupleTwo.<List<String>, List<List<Object>>>builder().build();
        if (pkColumnsIndex.size() > 0) {
            List<List<Object>> pkValues = getIndexValue4List(columnsValue, pkColumnsIndex, false);
            pkTuple = TupleTwo.<List<String>, List<List<Object>>>builder()
                    .v1(pkColumns).v2(pkValues).build();
        }

        List<TupleTwo<List<String>, List<List<Object>>>> tupleTwos = new ArrayList<>();
        int curr = 0;
        for (List<Integer> columnsIndex : ukColumnsIndex) {
            List<List<Object>> indexValueList = getIndexValue4List(columnsValue, columnsIndex, false);
            if (indexValueList.size() > 0) {
                TupleTwo<List<String>, List<List<Object>>> tupleTwo = TupleTwo.<List<String>, List<List<Object>>>builder()
                        .v1(ukNameList.get(curr)).v2(indexValueList).build();
                tupleTwos.add(tupleTwo);
            }
            curr++;
        }
        if (!pkTuple.isEmpty()) {
            tupleTwos.add(0, pkTuple);
        }
        // 拼装where子句，每个主键或唯一键组成一个子句，全部用or连接
        TupleTwo<String, List<Object>> whereSql = assembleWhere(tupleTwos);
        if (tupleTwos.size() == 0) {
            // 全部数据都可以插入
            return columnMetaBuilder.build();
        }
        // 查询主键
        String pkSelect = pkColumns.stream().collect(Collectors.joining(",", " select ", ","));
        StringBuilder builder = new StringBuilder(pkSelect);
        // 运行到这里就一定>0
        int outCycle = tupleTwos.size();
        for (int out = 0; out < outCycle; out++) {
            TupleTwo<List<String>, List<List<Object>>> tupleTwo = tupleTwos.get(out);
            int in = 0;
            for (String column : tupleTwo.getV1()) {
                builder.append(column);
                if (in != tupleTwo.getV1().size() - 1) {
                    builder.append(" , ");
                }
                in++;
            }
            if (out != outCycle - 1) {
                builder.append(" , ");
            }
        }
        String selectAllColumns = allColumns.stream().collect(Collectors.joining(",", " ", " "));
        builder.append(",").append(selectAllColumns)
                .append(" from ").append(this.getTableNameWithAlias())
                .append(" where ").append(whereSql.getV1());
        List<Object> parameters = whereSql.getV2();
        ResultSet resultSet = this.executeSimpleSql(builder.toString(), parameters, true);
        int columnCount = resultSet.getMetaData().getColumnCount();
        Set<String> keys = new HashSet<>();
        List<List<Object>> pkValueList = new ArrayList<>();
        Map<String, List<Object>> rows = new HashMap<>();
        while (resultSet.next()) {
            // 获取主键值
            int size = pkColumns.size();
            List<Object> pkValues = new ArrayList<>();
            for (int s = 1; s <= size; s++) {
                pkValues.add(resultSet.getObject(s));
            }
            String pkKey = this.assembleKey(pkColumns, pkValues);
            pkValueList.add(pkValues);
            // key列
            int currentIndex = size;
            for (TupleTwo<List<String>, List<List<Object>>> tupleTwo : tupleTwos) {
                List<String> v1 = tupleTwo.getV1();
                List<Object> values = new ArrayList<>();
                boolean skip = false;
                for (int j = 1; j <= v1.size(); j++) {
                    currentIndex++;
                    if (currentIndex > columnCount) {
                        throw new IllegalStateException("索引下标出现异常");
                    }
                    Object v = resultSet.getObject(currentIndex);
                    if (v == null) {
                        skip = true;
                    }
                    values.add(v);
                }
                if (!skip) {
                    String key = assembleKey(v1, values);
                    // 将返回的键值存起来，用来与插入时的键值比较，找出哪些是插入的数据
                    keys.add(key);
                }
            }
            List<Object> rowData = new ArrayList<>();
            for (String allColumn : allColumns) {
                currentIndex++;
                rowData.add(resultSet.getObject(currentIndex));
            }
            rows.put(pkKey, rowData);
        }
        return columnMetaBuilder.allKeys(keys).pkValues(pkValueList).rows(rows).build();
    }

    /**
     * 列合并
     *
     * @param target
     * @param columns
     * @param exclude 是否排除相同列
     * @return: void
     */
    private void addColumnNames(List<String> target, List<String> columns, boolean exclude) {
        if (CollectionUtil.isNotEmpty(columns)) {
            for (String column : columns) {
                if (target.contains(column) && !exclude || !target.contains(column)) {
                    target.add(column);
                }
            }
        }
    }

    /**
     * 过滤掉键值相同的行数据<br>
     * 过滤待插入数据中相同的数据，如果键值中某个值为空就忽略该键值（判断下一个键值），
     * 如果前面的数据已经包含了该键值，那么这条数就被过滤掉（<b>每行每行的比较，不是每列每列比较</b>）<br>
     *
     * @param tupleTwos
     * @param columnsValue
     * @return: void
     */
    private List<List<Object>> filter(List<TupleTwo<List<String>, List<Integer>>> tupleTwos, List<List<Object>> columnsValue) {
        Set<String> keys = new HashSet<>();
        int index = 0;
        boolean exit = false;
        int removeIndex = -1;
        for (List<Object> values : columnsValue) {
            for (TupleTwo<List<String>, List<Integer>> tupleTwo : tupleTwos) {
                List<String> v1 = tupleTwo.getV1();
                List<Integer> v2 = tupleTwo.getV2();
                List<Object> indexValueList = indexValue4List(values, v2, false);
                if (indexValueList.size() > 0) {
                    String key = assembleKey(v1, indexValueList);
                    if (keys.contains(key)) {
                        exit = true; //只要出现了，就跳出循环把那一条数据清理掉
                        removeIndex = index;
                        break;
                    }
                    keys.add(key);
                }
            }
            if (exit) {
                break;
            }
            index++;
        }
        if (removeIndex != -1) {
            columnsValue.get(removeIndex).clear();
            List<List<Object>> temp = columnsValue.stream().filter(v -> v.size() > 0).collect(Collectors.toList());
            return filter(tupleTwos, temp);
        } else {
            return columnsValue;
        }
    }

    /**
     * 获取插入的数据条数
     *
     * @param tupleTwos
     * @param columnsValue
     * @param keys
     * @return: int
     */
    private int getInsertCount(List<TupleTwo<List<String>, List<Integer>>> tupleTwos, List<List<Object>> columnsValue, Set<String> keys) {
        int count = 0;
        for (List<Object> values : columnsValue) {
            boolean exist = false;
            // 每一行只要有一个键值能查询到数据，则表示为更新
            for (TupleTwo<List<String>, List<Integer>> tupleTwo : tupleTwos) {
                List<String> v1 = tupleTwo.getV1();
                List<Integer> v2 = tupleTwo.getV2();
                List<Object> indexValueList = indexValue4List(values, v2, false);
                if (indexValueList.size() > 0) {
                    String key = assembleKey(v1, indexValueList);
                    if (keys.contains(key)) {
                        exist = true;
                        break;
                    }
                } else {
                    // 包含空值，则能插入
                }
            }
            if (!exist) {
                count++;
            }
        }
        return count;
    }


    /**
     * 没有指定主键和唯一键的行数
     *
     * @param tupleTwos
     * @param columnsValue
     * @return: int
     */
    private int getKeysAllNullRowCount(List<TupleTwo<List<String>, List<Integer>>> tupleTwos, List<List<Object>> columnsValue) {
        int count = 0;
        for (List<Object> values : columnsValue) {
            boolean allNull = true;
            // 每一行只要有一个键值能查询到数据，则表示为更新
            for (TupleTwo<List<String>, List<Integer>> tupleTwo : tupleTwos) {
                List<Integer> v2 = tupleTwo.getV2();
                List<Object> indexValueList = indexValue4List(values, v2, false);
                if (CollectionUtil.isNotEmpty(indexValueList)) {
                    allNull = false;
                    break;
                }
            }
            if (allNull) {
                count++;
            }
        }
        return count;
    }

    /**
     * 将所有的key组装成sql
     *
     * @param tupleTwos
     * @return: com.auditlog.datasource.struct.TupleTwo<java.lang.String, java.util.List < java.lang.Object>>
     */
    public TupleTwo<String, List<Object>> assembleWhere(List<TupleTwo<List<String>, List<List<Object>>>> tupleTwos) {
        int cycle = tupleTwos.size();
        StringBuilder builder = new StringBuilder();
        List<Object> allValues = new ArrayList<>();
        for (int c = 0; c < cycle; c++) {
            TupleTwo<List<String>, List<List<Object>>> tupleTwo = tupleTwos.get(c);
            String subSql = assembleWhere(tupleTwo.getV1(), tupleTwo.getV2());
            if (c == 0) {
                builder.append(subSql);
            } else {
                builder.append(" or ").append(subSql);
            }
            for (List<Object> objects : tupleTwo.getV2()) {
                allValues.addAll(objects);
            }
        }
        return TupleTwo.<String, List<Object>>builder().v1(builder.toString()).v2(allValues).build();
    }

    @Override
    public StatementMetaData getMetaData(Insert statement) {
        return StatementMetaData.builder()
                .sqlRunner(this)
                .sqlType(SqlType.INSERT_DUPLICATE)
                .tableName(this.getTableName())
                .build();
    }
}
